#! /bin/env python
#
# Generate dependencies between needs and configuration options, for
# libpok or for the kernel.

import os
import sys
import yaml

def make_variable(name, prefix):
    """Build a C preprocessor variable name of the form PREFIX_NAME.

    The prefix and the name are joined with an underscore and the whole
    result is upper-cased.

    >>> make_variable("memcpy", "pok_config_needs_func")
    'POK_CONFIG_NEEDS_FUNC_MEMCPY'
    """
    joined = "%s_%s" % (prefix, name)
    return joined.upper()

def expand_dependencies(deps):
    """Given a dictionary of dependencies, return an expanded list
    sorted by lexicographic order.

    The recognised keys are ``config_funcs`` (prefix
    ``POK_CONFIG_NEEDS_FUNC_``), ``funcs`` (prefix ``POK_NEEDS_FUNCTION_``)
    and ``needs`` (prefix ``POK_NEEDS_``); any other key is ignored.
    A missing or empty (``None``) dictionary or key contributes nothing.
    An empty YAML mapping entry such as ``needs:`` loads as ``None``.

    >>> expand_dependencies({"needs": ["a", "b"], "config_funcs": ["c"]})
    ['POK_CONFIG_NEEDS_FUNC_C', 'POK_NEEDS_A', 'POK_NEEDS_B']
    >>> expand_dependencies({"needs": None})
    []
    >>> expand_dependencies(None)
    []
    """
    if not deps:
        return []
    names = []
    for field, prefix in (("config_funcs", "pok_config_needs_func"),
                          ("funcs", "pok_needs_function"),
                          ("needs", "pok_needs")):
        # `or []` turns a YAML null into "no entries" instead of a TypeError.
        names.extend(make_variable(v, prefix) for v in deps.get(field) or [])
    return sorted(names)

def toposort(graph):
    """Topologically sort all nodes in a graph represented as a dictionary.

    ``graph`` maps a node to the nodes it points to. Nodes that only appear
    as targets are included in the result. Every node is placed before all
    the nodes it points to. Ties are broken deterministically from the
    sorted node names, so the output is stable across runs.

    Raises a plain ``Exception`` naming the node at which a cycle was
    detected.

    >>> toposort({"a": ["b", "c"], "d": ["a"]})
    ['d', 'a', 'b', 'c']
    >>> toposort({"a": ["b"], "b": ["a"]})
    Traceback (most recent call last):
    ...
    Exception: loop in dependencies through node b
    """
    nodes = set(graph)
    for targets in graph.values():
        nodes.update(targets)
    postorder = []
    done = set()
    in_progress = set()

    def visit(node):
        # Depth-first search. A node still `in_progress` that is reached
        # again lies on the current path, which means there is a cycle.
        if node in done:
            return
        if node in in_progress:
            raise Exception("loop in dependencies through node {}".format(node))
        in_progress.add(node)
        for successor in reversed(sorted(graph.get(node, []))):
            visit(successor)
        in_progress.discard(node)
        done.add(node)
        postorder.append(node)

    for node in reversed(sorted(nodes)):
        visit(node)
    # Reversed post-order is a topological order. Appending and reversing
    # once avoids the O(n^2) cost of repeated list.insert(0, ...).
    postorder.reverse()
    return postorder

def generate_define(var):
    """Return a C preprocessor stanza that defines ``var`` to 1 unless it
    is already defined.

    >>> print(generate_define("X"))
    #ifndef X
    #define X 1
    #endif // !X
    """
    stanza = [
        "#ifndef %s" % (var,),
        "#define %s 1" % (var,),
        "#endif // !%s" % (var,),
    ]
    return "\n".join(stanza)

def generate_dependencies(check, vars):
    """Wrap a define stanza for each name in ``vars`` inside ``#ifdef check``.

    The stanzas are separated by blank lines. The parameter name ``vars``
    shadows the builtin but is kept for callers passing it by keyword.
    """
    defines = [generate_define(name) for name in vars]
    body = "\n\n".join(defines)
    return "#ifdef %s\n\n%s\n\n#endif // %s" % (check, body, check)

def generate_dependency_blocks(block):
    """Return the conditional ``#ifdef`` sections for a mapping of needs.

    Each key of ``block`` becomes a ``POK_NEEDS_<KEY>`` check whose section
    defines the variables expanded from its value. Sections are emitted in
    topological order, so a check comes before the checks it enables. This
    lets a variable defined in one section be seen by a later section.
    Nodes reached only as dependencies get no section of their own.
    """
    deps = {}
    for need, requirements in block.items():
        deps[make_variable(need, "pok_needs")] = expand_dependencies(requirements)
    sections = []
    for check in toposort(deps):
        if check not in deps:
            continue
        sections.append(generate_dependencies(check, deps[check]))
    return "\n\n".join(sections)

def _write_defines(fd, deps):
    """Write an unconditional define stanza for each variable in ``deps``."""
    # expand_dependencies already returns its result sorted.
    for var in expand_dependencies(deps):
        fd.write("{}\n\n".format(generate_define(var)))

def generate_output(fd, config, source):
    """Write the generated dependency header for ``config`` to ``fd``.

    ``fd`` is any object with a ``write`` method. ``source`` is the name of
    the configuration file, recorded in the header comment. ``config`` must
    contain ``domain``, which is used to build the include guard. It may
    also contain these optional sections:

    * ``deployment``: ``condition.before`` / ``condition.after`` lines are
      written around ``#include "deployment.h"``;
    * ``always``: the ``dependencies`` are defined unconditionally;
    * ``conditional``: the ``dependencies`` mapping is turned into
      ``#ifdef`` sections by generate_dependency_blocks;
    * ``unoptimized``: the ``dependencies`` are defined only when the macro
      named by ``nguard`` is not defined.
    """
    fd.write("// This file has been automatically generated by gen_dependencies\n")
    # The warning refers to the generated file itself; edits belong in `source`.
    fd.write("// from the file {}. Do not make manual modifications here or they\n".format(source))
    fd.write("// will be lost.\n\n")

    guard = "__{}_DEPENDENCIES_H__".format(config["domain"].upper())
    fd.write("#ifndef {0}\n#define {0}\n\n".format(guard))
    if "deployment" in config:
        condition = config["deployment"]["condition"]
        fd.write('{}\n#include "deployment.h"\n{}\n\n'.format(
            condition["before"], condition["after"]
        ))
    if "always" in config:
        _write_defines(fd, config["always"]["dependencies"])
    if "conditional" in config:
        fd.write("{}\n\n".format(generate_dependency_blocks(config["conditional"]["dependencies"])))
    if "unoptimized" in config:
        nguard = config["unoptimized"]["nguard"]
        fd.write("#ifndef {}\n\n".format(nguard))
        _write_defines(fd, config["unoptimized"]["dependencies"])
        fd.write("#endif // {}\n\n".format(nguard))
    fd.write("#endif // {}\n".format(guard))

def main():
    """Command-line entry point.

    ``gen_dependencies --test`` runs this module's doctests and exits with
    status 1 if any of them fail.

    ``gen_dependencies CONFIG OUTPUT`` loads the YAML file CONFIG, generates
    the dependency header and pipes it through ``clang-format`` into OUTPUT.
    The script exits with an error if clang-format fails.
    """
    if len(sys.argv) > 1 and sys.argv[1] == "--test":
        import doctest
        failed, _ = doctest.testmod()
        if failed:
            sys.exit(1)
        return
    if len(sys.argv) < 3:
        sys.exit("usage: {} CONFIG OUTPUT | --test".format(sys.argv[0]))
    import subprocess
    with open(sys.argv[1]) as config_file:
        config = yaml.safe_load(config_file)
    source = os.path.basename(sys.argv[1])
    with open(sys.argv[2], "w") as output:
        # Run clang-format without a shell so the output path is never
        # interpreted by one (spaces, metacharacters).
        formatter = subprocess.Popen(["clang-format"], stdin=subprocess.PIPE,
                                     stdout=output, universal_newlines=True)
        try:
            generate_output(formatter.stdin, config, source)
        finally:
            # Closing stdin sends EOF so clang-format can finish writing.
            formatter.stdin.close()
            status = formatter.wait()
    if status != 0:
        sys.exit("clang-format exited with status {}".format(status))

# Run the generator only when executed as a script, not when imported.
if __name__ == '__main__':
    main()